from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2VLForConditionalGeneration, AutoProcessor, AutoModelForCausalLM
from typing import Dict, Any, Union
from trl.data_utils import maybe_apply_chat_template
import torch
import re
import numpy as np
import sys
# Add the directory containing utils_aitw.py to sys.path (the path is relative to the current working directory)
sys.path.append('./vlm_modules/')
from utils_aitw import pred2json_post, action2json, check_actions_match

from vlm_modules.vlm_module import VLMBaseModule

class Qwen2VLModule(VLMBaseModule):
    """VLM adapter for Qwen-VL models, plus reward functions for GUI-agent RL.

    The reward functions are static methods with the signature
    ``(completions, solution, **kwargs) -> list`` (the ``format_reward*``
    functions take no ``solution``). ``completion[0]["content"]`` is read as
    the generated text of each completion, and rewards are produced pairwise
    over ``zip(completions, solution)``.
    """

    def __init__(self):
        super().__init__()

    def get_vlm_key(self):
        """Return the short key identifying this VLM family."""
        return "qwen"

    def get_model_class(self, model_id: str, model_init_kwargs: dict):
        """Return the transformers model class to load for ``model_id``.

        ``model_init_kwargs`` is accepted for interface compatibility and unused.

        Raises:
            ValueError: if ``model_id`` matches no supported Qwen-VL variant.
        """
        # Substring checks: "Qwen2.5-VL" does not contain "Qwen2-VL", and
        # neither contains "Qwen-VL", so the check order cannot misclassify.
        if "Qwen2-VL" in model_id:
            return Qwen2VLForConditionalGeneration
        if "Qwen2.5-VL" in model_id:
            return Qwen2_5_VLForConditionalGeneration
        if "Qwen-VL" in model_id:
            return AutoModelForCausalLM
        raise ValueError(f"Unsupported model: {model_id}")

    def post_model_init(self, model, processing_class):
        """No post-initialization is performed for Qwen-VL models."""
        pass

    def get_processing_class(self):
        """Return the processor factory used to build model inputs."""
        return AutoProcessor

    def get_vision_modules_keywords(self):
        """Return name fragments identifying vision-tower modules."""
        return ['visual']

    def get_custom_multimodal_keywords(self):
        """Return the processor output keys that carry image data."""
        return ['pixel_values', 'image_grid_thw']

    def get_non_generate_params(self):
        """Return the non-generate parameter names (none for Qwen-VL)."""
        return []

    def get_custom_processing_keywords(self):
        """Return the processor keyword names this module understands."""
        return ['max_pixels', 'min_pixels']

    def prepare_prompt(self, processing_class, inputs: dict[str, Union[torch.Tensor, Any]]):
        """Render each example's chat messages into a prompt string.

        Despite the annotation, ``inputs`` is iterated as a sequence of
        examples, each passed to ``maybe_apply_chat_template``.
        """
        return [maybe_apply_chat_template(example, processing_class)["prompt"] for example in inputs]

    def prepare_model_inputs(self, processing_class, prompts_text, images, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False):
        """Tokenize prompts (and images, when given) with ``processing_class``.

        FIXME: a batch is treated as either fully multimodal (``images``
        non-empty) or text-only; mixed batches are not supported.
        """
        processor_kwargs = dict(
            text=prompts_text,
            return_tensors=return_tensors,
            padding=padding,
            padding_side=padding_side,
            add_special_tokens=add_special_tokens,
        )
        if len(images) > 0:
            processor_kwargs["images"] = images
        return processing_class(**processor_kwargs)

    @staticmethod
    def get_question_template(task_type: str):
        """Return the question template; ``"{Question}"`` for every task type."""
        return "{Question}"

    # ------------------------------------------------------------------
    # Private helpers shared by the reward functions
    # ------------------------------------------------------------------

    @staticmethod
    def _timestamp():
        """Return the timestamp string used to tag debug-log records."""
        from datetime import datetime
        return datetime.now().strftime("%d-%H-%M-%S-%f")

    @staticmethod
    def _debug_log(current_time, reward, content, sol):
        """Append one reward record to ``$LOG_PATH`` when ``DEBUG_MODE=true``.

        Does nothing when ``LOG_PATH`` is unset; creates the parent directory
        only when the path has one.
        """
        import os
        if os.getenv("DEBUG_MODE") != "true":
            return
        log_path = os.getenv("LOG_PATH")
        if not log_path:
            return
        log_dir = os.path.dirname(log_path)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        with open(log_path, "a", encoding='utf-8') as f:
            f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
            f.write(f"Content: {content}\n")
            f.write(f"Solution: {sol}\n")

    @staticmethod
    def _extract_answers(completions, solution):
        """Return ``(student_answers, ground_truths)`` with ``<answer>`` tags stripped."""
        ground_truth = [Qwen2VLModule.extract_answer(sol) for sol in solution]
        student_answer = [Qwen2VLModule.extract_answer(completion[0]["content"]) for completion in completions]
        return student_answer, ground_truth

    @staticmethod
    def _point_in_bbox(point, bbox):
        """Return True if ``point`` (x, y) lies inside ``bbox`` (x1, y1, x2, y2), edges inclusive."""
        return (bbox[0] <= point[0] <= bbox[2]) and (bbox[1] <= point[1] <= bbox[3])

    @staticmethod
    def _text_overlap(pred, ref):
        """Return True if either value contains the other.

        Containment is checked before equality, so ``None`` values raise
        ``TypeError``; callers treat that as a miss.
        """
        return pred in ref or ref in pred or pred == ref

    @staticmethod
    def _action_type_rewards(completions, solution, log_debug):
        """Return 1.0 per completion whose action type matches the reference (case-insensitive), else 0.0."""
        student_answer, ground_truth = Qwen2VLModule._extract_answers(completions, solution)
        current_time = Qwen2VLModule._timestamp()
        rewards = []
        for content, sol in zip(student_answer, ground_truth):
            reward = 0.0
            try:
                action_pred = Qwen2VLModule.pred2json(content)
                answer = Qwen2VLModule.pred2json(sol)
                if action_pred["action"].lower() == answer["action"].lower():
                    reward = 1.0
            except Exception as e:
                print(f"Accuracy Reward Function Error: {e}")
            rewards.append(reward)
            if log_debug:
                Qwen2VLModule._debug_log(current_time, reward, content, sol)
        return rewards

    @staticmethod
    def _bbox_position_rewards(completions, solution, kwargs, report):
        """Return 1.0 per completion whose predicted position lies in ``kwargs['bbox'][i]``, else 0.0.

        When ``report`` is False, parse errors are swallowed silently and no
        debug log is written; each failing item still scores 0.0.
        """
        student_answer, ground_truth = Qwen2VLModule._extract_answers(completions, solution)
        current_time = Qwen2VLModule._timestamp()
        rewards = []
        for i, (content, sol) in enumerate(zip(student_answer, ground_truth)):
            reward = 0.0
            try:
                action_pred = Qwen2VLModule.pred2json(content)
                # Parsed only for validation: an unparseable reference scores 0.0.
                Qwen2VLModule.pred2json(sol)
                if Qwen2VLModule._point_in_bbox(action_pred["position"], kwargs['bbox'][i]):
                    reward = 1.0
            except Exception as e:
                if report:
                    print(f"Accuracy Reward Function Error: {e}")
            rewards.append(reward)
            if report:
                Qwen2VLModule._debug_log(current_time, reward, completions[i][0]['content'], sol)
        return rewards

    # ------------------------------------------------------------------
    # Format rewards
    # ------------------------------------------------------------------

    @staticmethod
    def format_reward_rec(completions, **kwargs):
        """Reward 0.5 when ``<think>...</think>`` is followed by ``<answer>...</answer>``, else 0.0."""
        pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
        completion_contents = [completion[0]["content"] for completion in completions]
        return [0.5 if re.search(pattern, content, re.DOTALL) else 0.0 for content in completion_contents]

    @staticmethod
    def format_reward_think(completions, **kwargs):
        """Reward 0.1 for each structured-reasoning section present in the output.

        Each of the four tagged sections must be followed by a whitespace
        character to count.
        """
        # NOTE: "Decesion" is kept exactly as in the original pattern; it
        # presumably mirrors the tag requested by the prompt — confirm before
        # changing it.
        patterns = [r"<Task Analysis>.*?</Task Analysis>\s", r"<Progress Estimation>.*?</Progress Estimation>\s", r"<Decesion Making>.*?</Decesion Making>\s", r"<History Summary>.*?</History Summary>\s"]
        completion_contents = [completion[0]["content"] for completion in completions]
        return [
            sum(0.1 for pattern in patterns if re.search(pattern, content, re.DOTALL))
            for content in completion_contents
        ]

    @staticmethod
    def format_reward(completions, **kwargs):
        """Reward 1.0 when ``<think>`` is followed by an ``<answer>`` holding a
        ``[{"bbox_2d": [x1, y1, x2, y2], "label": "..."}...]`` style list, else 0.0."""
        pattern = r"<think>.*?</think>\s*<answer>.*?\[.*?{\"bbox_2d\":\s*\[\s*\d+,\s*\d+,\s*\d+,\s*\d+\s*\]\s*,\s*\"label\":\s*\".*?\"\s*}.*?\].*?</answer>"
        completion_contents = [completion[0]["content"] for completion in completions]
        return [1.0 if re.search(pattern, content, re.DOTALL) else 0.0 for content in completion_contents]

    # ------------------------------------------------------------------
    # Accuracy rewards
    # ------------------------------------------------------------------

    @staticmethod
    def iou_reward(completions, solution, **kwargs):
        """Reward 1.0 when the first ``[x1, y1, x2, y2]`` box inside ``<answer>``
        has IoU > 0.5 with the reference box ``sol``, else 0.0."""
        def iou(box1, box2):
            inter_x1 = max(box1[0], box2[0])
            inter_y1 = max(box1[1], box2[1])
            inter_x2 = min(box1[2]-1, box2[2]-1)
            inter_y2 = min(box1[3]-1, box2[3]-1)
            if inter_x1 < inter_x2 and inter_y1 < inter_y2:
                inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
            else:
                inter = 0
            union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter
            if union <= 0:
                # Degenerate boxes: no meaningful overlap.
                return 0.0
            return float(inter)/union

        contents = [completion[0]["content"] for completion in completions]
        current_time = Qwen2VLModule._timestamp()
        answer_tag_pattern = r'<answer>(.*?)</answer>'
        bbox_pattern = r'\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)]'
        rewards = []
        for content, sol in zip(contents, solution):
            reward = 0.0
            try:
                answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
                bbox_match = re.search(bbox_pattern, answer_match.group(1).strip()) if answer_match else None
                if bbox_match:
                    bbox = [int(group) for group in bbox_match.groups()]
                    if iou(bbox, sol) > 0.5:
                        reward = 1.0
            except Exception:
                # Best-effort: a malformed reference box simply earns no reward.
                pass
            rewards.append(reward)
            Qwen2VLModule._debug_log(current_time, reward, content, sol)
        return rewards

    @staticmethod
    def is_location_close(loc1, loc2, threshold=0.1):
        """Return True if the Euclidean distance between two points is below ``threshold``.

        :param loc1: first point ``(x1, y1)``
        :param loc2: second point ``(x2, y2)``
        :param threshold: exclusive distance limit, in the same units as the
            coordinates (default 0.1). Callers in this class pass 0.04-0.15 for
            positions parsed by ``pred2json`` and 10-20 for the integer
            coordinates produced by ``parse_action``.
        :return: True if the distance is strictly less than ``threshold``.
        """
        import math
        x1, y1 = loc1
        x2, y2 = loc2
        distance = math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
        return distance < threshold

    @staticmethod
    def pred2json(prediction):
        """Parse ``'action': '<name>', 'value': <'text'|None>, 'position': <[x, y]|None>``.

        Double quotes are normalized to single quotes before matching.

        Returns:
            dict with keys ``action`` (str), ``value`` (str or None) and
            ``position`` (``[x, y]`` floats or None).

        Raises:
            ValueError: if the string does not match the expected format.
        """
        prediction = prediction.replace('\"', '\'')
        pattern = r"'action':\s*'(.*?)',\s*'value':\s*(None|'(.*?)'),\s*'position':\s*(None|\[([0-9.]+),\s*([0-9.]+)\])"
        match = re.search(pattern, prediction)
        if not match:
            raise ValueError(f"Input string '{prediction}' doesn't match the expected format")

        value = None if match.group(2) == 'None' else match.group(3)
        if match.group(4) == 'None':
            position = None
        else:
            position = [float(match.group(5)), float(match.group(6))]
        return {
            "action": match.group(1),
            "value": value,
            "position": position
        }

    @staticmethod
    def extract_answer(text):
        """Return the stripped contents of the first ``<answer>`` tag.

        If there is no ``<answer>`` tag, return the whole text stripped.
        """
        match = re.search(r'<answer>(.*?)</answer>', text, re.DOTALL)
        return match.group(1).strip() if match else text.strip()

    @staticmethod
    def mind2web_verify_action(completions, solution, **kwargs):
        """Reward 1.0 when the action type matches (case-insensitive) and the
        predicted position is within 0.1 of the reference position, else 0.0."""
        student_answer, ground_truth = Qwen2VLModule._extract_answers(completions, solution)
        current_time = Qwen2VLModule._timestamp()
        rewards = []
        for content, sol in zip(student_answer, ground_truth):
            reward = 0.0
            try:
                action_pred = Qwen2VLModule.pred2json(content)
                answer = Qwen2VLModule.pred2json(sol)
                if (action_pred["action"].lower() == answer["action"].lower()
                        and Qwen2VLModule.is_location_close(action_pred["position"], answer["position"])):
                    reward = 1.0
            except Exception as e:
                print(f"Accuracy Reward Function Error: {e}")
            rewards.append(reward)
            Qwen2VLModule._debug_log(current_time, reward, content, sol)
        return rewards

    @staticmethod
    def mind2web_verify_action_type(completions, solution, **kwargs):
        """Reward 1.0 when the action type matches the reference (case-insensitive), else 0.0.

        Does not write debug logs.
        """
        return Qwen2VLModule._action_type_rewards(completions, solution, log_debug=False)

    @staticmethod
    def mind2web_verify_action_position(completions, solution, **kwargs):
        """Reward 1.0 when the predicted position lies inside ``kwargs['bbox'][i]``, else 0.0.

        The action type is not checked; the reference is parsed only so that an
        unparseable solution scores 0.0.
        """
        return Qwen2VLModule._bbox_position_rewards(completions, solution, kwargs, report=True)

    @staticmethod
    def mind2web_verify_action_position_no_report(completions, solution, **kwargs):
        """Same reward as ``mind2web_verify_action_position`` without error printing or debug logs.

        Returns one reward per completion.
        """
        return Qwen2VLModule._bbox_position_rewards(completions, solution, kwargs, report=False)

    @staticmethod
    def mind2web_select_grouding_error(completions, solution, **kwargs):
        """Reward 1.0 for near misses: predicted position at distance >= 0.1 but
        < 0.15 from the reference position (action type not checked), else 0.0."""
        student_answer, ground_truth = Qwen2VLModule._extract_answers(completions, solution)
        rewards = []
        for content, sol in zip(student_answer, ground_truth):
            reward = 0.0
            try:
                click_point = Qwen2VLModule.pred2json(content)["position"]
                answer_point = Qwen2VLModule.pred2json(sol)["position"]
                outside_strict = not Qwen2VLModule.is_location_close(click_point, answer_point)
                inside_loose = Qwen2VLModule.is_location_close(click_point, answer_point, threshold=0.15)
                if outside_strict and inside_loose:
                    reward = 1.0
            except Exception as e:
                print(f"Accuracy Reward Function Error: {e}")
            rewards.append(reward)
        return rewards

    @staticmethod
    def mind2web_verify_action_compare_last(completions, solution, **kwargs):
        """Give every completion a bonus based on comparing ``kwargs['lats_completions']`` with the current rewards.

        ``kwargs['lats_completions']`` (the key is spelled as callers pass it)
        is scored with the position and type rewards. Every completion then
        receives:

        * 0.7 if the summed position reward exceeds the sum of column 1 of
          ``kwargs['rewards_per_func']``;
        * plus 0.3 if the summed type reward exceeds the sum of column 0.

        ``rewards_per_func`` is presumably a (num_completions, num_reward_funcs)
        tensor with the type reward in column 0 and the position reward in
        column 1 — TODO confirm against the trainer.
        """
        last_completions = kwargs['lats_completions']
        # Position and type are scored by their respective reward functions.
        last_position_rewards = Qwen2VLModule.mind2web_verify_action_position(last_completions, solution, **kwargs)
        last_type_rewards = Qwen2VLModule.mind2web_verify_action_type(last_completions, solution, **kwargs)
        rewards_per_func = kwargs['rewards_per_func']

        position_bonus = 0.7 if sum(last_position_rewards) > rewards_per_func[:, 1].sum().item() else 0
        type_bonus = 0.3 if sum(last_type_rewards) > rewards_per_func[:, 0].sum().item() else 0
        return [position_bonus + type_bonus] * len(last_position_rewards)

    @staticmethod
    def mind2web_verify_action_format(completions, solution, **kwargs):
        """Reward 0.5 when the extracted answer parses with ``pred2json``, else 0.0."""
        student_answer, ground_truth = Qwen2VLModule._extract_answers(completions, solution)
        rewards = []
        # zip keeps the output length equal to min(len(completions), len(solution)).
        for content, _sol in zip(student_answer, ground_truth):
            try:
                Qwen2VLModule.pred2json(content)
            except Exception:
                rewards.append(0.0)
            else:
                rewards.append(0.5)
        return rewards

    @staticmethod
    def parse_action(action_str):
        """Parse an action string into ``(action_type, args)``.

        Keyword detection is case-insensitive and checked in this order:
        ``press_home``, ``press_back``, ``press_enter``, ``finished`` (no args),
        ``type`` (``{"content": str}``), ``scroll`` (``start_box``/``end_box``
        integer tuples) and ``click`` (``start_box``). Returns ``(None, None)``
        when no keyword is found or the arguments cannot be parsed.
        """
        lowered = action_str.lower()
        for keyword in ('press_home', 'press_back', 'press_enter', 'finished'):
            if keyword in lowered:
                return keyword, {}
        if 'type' in lowered:
            content_match = re.search(r"type\(content=\('([^']*)'\)\)", action_str)
            if content_match:
                return 'type', {"content": content_match.group(1)}
            return None, None
        if 'scroll' in lowered:
            start_box_match = re.search(r"start_box='\((\d+),(\d+)\)'", action_str)
            end_box_match = re.search(r"end_box='\((\d+),(\d+)\)'", action_str)
            if start_box_match and end_box_match:
                return 'scroll', {
                    "start_box": (int(start_box_match.group(1)), int(start_box_match.group(2))),
                    "end_box": (int(end_box_match.group(1)), int(end_box_match.group(2))),
                }
            return None, None
        if 'click' in lowered:
            # Accept both the quoted start_box='(x,y)' form (as used by scroll)
            # and an unquoted start_box=(x,y).
            start_box_match = re.search(r"start_box='?\((\d+),(\d+)\)", action_str)
            if start_box_match:
                return 'click', {"start_box": (int(start_box_match.group(1)), int(start_box_match.group(2)))}
            return None, None
        return None, None

    @staticmethod
    def verify_action(student_action, ground_truth_action, accelerator):
        """Return 1.0 if the student action matches the reference action, else 0.0.

        Every ``Action:`` marker is removed from the student string before
        both are parsed with ``parse_action``. ``accelerator`` is unused.
        """
        student_type, student_args = Qwen2VLModule.parse_action(student_action.replace('Action:', ''))
        ground_truth_type, ground_truth_args = Qwen2VLModule.parse_action(ground_truth_action)

        if student_type != ground_truth_type:
            return 0.0  # action types differ

        if student_type == "click":
            # Integer coordinates from parse_action, hence the larger threshold.
            close = Qwen2VLModule.is_location_close(student_args["start_box"], ground_truth_args["start_box"], threshold=10)
            return 1.0 if close else 0.0
        if student_type == "type":
            student_text = student_args["content"]
            truth_text = ground_truth_args["content"]
            matched = student_text == truth_text or student_text in truth_text or truth_text in student_text
            return 1.0 if matched else 0.0
        if student_type == "scroll":
            start_close = Qwen2VLModule.is_location_close(student_args["start_box"], ground_truth_args["start_box"], threshold=20)
            end_close = Qwen2VLModule.is_location_close(student_args["end_box"], ground_truth_args["end_box"], threshold=20)
            return 1.0 if (start_close and end_close) else 0.0
        if student_type in ("press_home", "press_back", "press_enter", "finished"):
            return 1.0  # argument-free actions: a type match is enough
        return 0.0  # both unparseable (None) or an unknown type

    @staticmethod
    def aitz_verify_action_type(completions, solution, **kwargs):
        """Reward 1.0 when the action type matches the reference (case-insensitive), else 0.0.

        Writes debug logs when ``DEBUG_MODE=true``.
        """
        return Qwen2VLModule._action_type_rewards(completions, solution, log_debug=True)

    @staticmethod
    def aitz_verify_action_position(completions, solution, **kwargs):
        """AITZ reward: 1.0 for a correct click or typed text, else 0.0.

        The reference action type (from ``solution``) selects the check:

        * ``click``: ``check_actions_match`` on the predicted action (via
          ``pred2json_post``) vs ``action2json(kwargs['step'][i])``, with
          ``kwargs['bbox'][i]`` regrouped into rows of 4 values;
        * ``type``/``select``: typed texts equal or one contains the other.
        """
        coords_per_box = 4
        student_answer, ground_truth = Qwen2VLModule._extract_answers(completions, solution)
        current_time = Qwen2VLModule._timestamp()
        rewards = []
        for idx, (content, sol) in enumerate(zip(student_answer, ground_truth)):
            reward = 0.0
            try:
                # ``sol`` is rebound to the parsed dict, so the debug log shows it parsed.
                sol = Qwen2VLModule.pred2json(sol)
                action_pred = pred2json_post(Qwen2VLModule.pred2json(content))

                flat_boxes = kwargs['bbox'][idx]
                action_ref = action2json(kwargs['step'][idx])
                # Regroup the flat annotation list into rows of 4 values.
                annot_position = np.array([flat_boxes[i:i + coords_per_box]
                                           for i in range(0, len(flat_boxes), coords_per_box)])
                check_match = check_actions_match(action_pred["touch_point"],
                                                  action_pred["lift_point"],
                                                  action_pred["action_type"],
                                                  action_ref["touch_point"],
                                                  action_ref["lift_point"],
                                                  action_ref["action_type"],
                                                  annot_position)

                ref_action = sol['action'].lower()
                if ref_action == 'click':
                    # check_actions_match presumably returns a bool — confirm in utils_aitw.
                    reward = 1.0 if check_match else 0.0
                elif ref_action in ('type', 'select'):
                    pred_text = action_pred["typed_text"]
                    ref_text = action_ref["typed_text"]
                    matched = (pred_text == ref_text) or (pred_text in ref_text) or (ref_text in pred_text)
                    reward = 1.0 if matched else 0.0
            except Exception as e:
                print(f"Accuracy Position Reward Function Error: {e}")
            rewards.append(reward)
            Qwen2VLModule._debug_log(current_time, reward, content, sol)
        return rewards

    @staticmethod
    def guiact_verify_action_position(completions, solution, **kwargs):
        """GUIAct reward: 2.0 for a correct action, else 0.0.

        Completions and solutions are parsed with ``ast.literal_eval`` into dicts
        with ``action``/``position``/``value`` keys. Action types must match
        (case-insensitive). The argument check then depends on the platform read
        from ``kwargs['id'][i]`` (``'web'`` is checked before ``'smartphone'``):

        * point actions: inside ``kwargs['bbox'][i]`` when it is not None,
          otherwise within 0.05 of the reference position;
        * text actions: predicted and reference ``value`` contain one another;
        * ``scroll`` (web only): ``value`` equal ignoring case;
        * range actions (``select_text`` / ``swipe``): both endpoints within 0.05.
        """
        import ast

        # Per-platform action groups: (point, text, scroll, range).
        web_groups = (('click', 'select', 'input', 'hover'), ('answer', 'copy'), ('scroll',), ('select_text',))
        phone_groups = (('tap',), ('answer', 'input'), (), ('swipe',))
        no_groups = ((), (), (), ())

        student_answer, ground_truth = Qwen2VLModule._extract_answers(completions, solution)
        current_time = Qwen2VLModule._timestamp()
        rewards = []
        for i, (content, sol) in enumerate(zip(student_answer, ground_truth)):
            reward = 0.0
            try:
                # literal_eval accepts only Python literals, so model output is never executed.
                action_pred = ast.literal_eval(content)
                answer = ast.literal_eval(sol)
                action = action_pred["action"].lower()
                if action == answer["action"].lower():
                    sample_id = kwargs['id'][i]
                    if 'web' in sample_id:
                        groups = web_groups
                    elif 'smartphone' in sample_id:
                        groups = phone_groups
                    else:
                        groups = no_groups
                    point_actions, text_actions, scroll_actions, range_actions = groups

                    correct = False
                    if action in point_actions:
                        click_point = action_pred["position"]
                        answer_point = answer["position"]
                        bbox_ref = kwargs['bbox'][i]
                        if bbox_ref is not None:
                            correct = Qwen2VLModule._point_in_bbox(click_point, bbox_ref)
                        else:
                            correct = Qwen2VLModule.is_location_close(click_point, answer_point, threshold=0.05)
                    elif action in text_actions:
                        correct = Qwen2VLModule._text_overlap(action_pred["value"], answer["value"])
                    elif action in scroll_actions:
                        correct = action_pred["value"].lower() == answer["value"].lower()
                    elif action in range_actions:
                        click_point = action_pred["position"]
                        answer_point = answer["position"]
                        correct = (
                            Qwen2VLModule.is_location_close(click_point[0], answer_point[0], threshold=0.05)
                            and Qwen2VLModule.is_location_close(click_point[1], answer_point[1], threshold=0.05)
                        )
                    if correct:
                        reward = 2.0
            except Exception as e:
                print(f"Accuracy Reward Function Error: {e}")
            rewards.append(reward)
            Qwen2VLModule._debug_log(current_time, reward, completions[i][0]['content'], sol)
        return rewards

    @staticmethod
    def miniwob_verify_action_position(completions, solution, **kwargs):
        """MiniWoB reward: 3.0 for a correct click or type action, else 0.0.

        * click: both are clicks and the predicted position (not None) is
          within 0.04 of the reference position;
        * type: both are type actions and the predicted value (not None)
          equals the reference value.
        """
        student_answer, ground_truth = Qwen2VLModule._extract_answers(completions, solution)
        current_time = Qwen2VLModule._timestamp()
        rewards = []
        for i, (content, sol) in enumerate(zip(student_answer, ground_truth)):
            reward = 0.0
            try:
                action_pred = Qwen2VLModule.pred2json(content)
                answer = Qwen2VLModule.pred2json(sol)
                pred_action = action_pred["action"].lower()
                answer_action = answer["action"].lower()

                if answer_action == 'click' and pred_action == 'click':
                    click_point = action_pred["position"]
                    if click_point is not None and Qwen2VLModule.is_location_close(click_point, answer["position"], threshold=0.04):
                        reward = 3.0
                elif answer_action == 'type' and pred_action == 'type':
                    pred_value = action_pred["value"]
                    if pred_value is not None and pred_value == answer["value"]:
                        reward = 3.0
            except Exception as e:
                print(f"Position Accuracy Reward Function Error: {e}")
            rewards.append(reward)
            Qwen2VLModule._debug_log(current_time, reward, completions[i][0]['content'], sol)
        return rewards